{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "912cd8c6-d405-4dfe-8897-46108e6a6af7",
   "metadata": {},
   "source": [
    "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631b09a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: An OpenAI API key must be set for RAPTOR to initialize, even if you do not use OpenAI models.\n",
    "# If you are not using OpenAI models, set a placeholder string (e.g., \"not_used\").\n",
    "import os\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"your-openai-key\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d7d995-7beb-40b5-9a44-afd350b7d221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Cinderella story from demo/sample.txt\n",
    "with open('demo/sample.txt', 'r', encoding='utf-8') as file:\n",
    "    text = file.read()\n",
    "\n",
    "print(text[:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7d51ebd-5597-4fdd-8c37-32636395081b",
   "metadata": {},
   "source": [
    "1) **Building**: RAPTOR recursively embeds, clusters, and summarizes chunks of text, constructing a tree from the bottom up in which higher layers hold increasingly abstract summaries. You can build a tree from the text in `demo/sample.txt` with `RA.add_documents(text)`.\n",
    "\n",
    "2) **Querying**: At inference time, RAPTOR retrieves information from this tree, integrating content from across long documents at different levels of abstraction. You can query the tree with `RA.answer_question(question)`."
   ]
  },
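  {
   "cell_type": "markdown",
   "id": "raptor-toy-sketch-intro",
   "metadata": {},
   "source": [
    "The two phases can be illustrated with a toy, dependency-free sketch. It is **not** the `raptor` library's implementation: it assumes a bag-of-words stand-in for the embedding model, a greedy similarity grouping in place of RAPTOR's clustering step, and a stand-in summarizer that simply keeps the first few words. All helper names (`embed`, `cosine`, `cluster`, `summarize`, `build_tree`, `retrieve`) are hypothetical. Retrieval scores nodes from every layer of the tree, so a query can match either leaf chunks or higher-level summaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "raptor-toy-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch of bottom-up tree building and retrieval (NOT the raptor library's code).\n",
    "# Assumptions: bag-of-words \"embeddings\", greedy similarity grouping instead of\n",
    "# RAPTOR's clustering, and a truncating stand-in for the LLM summarizer.\n",
    "import math\n",
    "import re\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def embed(text):\n",
    "    # Hypothetical stand-in embedding: lowercase word counts\n",
    "    return Counter(re.findall(r\"\\w+\", text.lower()))\n",
    "\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Cosine similarity between two sparse word-count vectors\n",
    "    dot = sum(count * b[word] for word, count in a.items())\n",
    "    norm_a = math.sqrt(sum(v * v for v in a.values()))\n",
    "    norm_b = math.sqrt(sum(v * v for v in b.values()))\n",
    "    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0\n",
    "\n",
    "\n",
    "def cluster(texts, threshold=0.2):\n",
    "    # Greedy grouping: each text joins the first group whose seed text is similar enough\n",
    "    groups = []\n",
    "    for i, text in enumerate(texts):\n",
    "        for group in groups:\n",
    "            if cosine(embed(text), embed(texts[group[0]])) >= threshold:\n",
    "                group.append(i)\n",
    "                break\n",
    "        else:\n",
    "            groups.append([i])\n",
    "    return groups\n",
    "\n",
    "\n",
    "def summarize(texts, max_words=12):\n",
    "    # Hypothetical stand-in for an LLM summarizer: keep the first few words\n",
    "    return \" \".join(\" \".join(texts).split()[:max_words])\n",
    "\n",
    "\n",
    "def build_tree(chunks):\n",
    "    # Layer 0 holds the leaf chunks; each new layer summarizes the groups below it\n",
    "    layers = [chunks]\n",
    "    while len(layers[-1]) > 1:\n",
    "        current = layers[-1]\n",
    "        groups = cluster(current)\n",
    "        if len(groups) == len(current):\n",
    "            # Nothing merged: pair up neighbours so every layer is smaller than the last\n",
    "            groups = [list(range(i, min(i + 2, len(current)))) for i in range(0, len(current), 2)]\n",
    "        layers.append([summarize([current[i] for i in group]) for group in groups])\n",
    "    return layers\n",
    "\n",
    "\n",
    "def retrieve(layers, query, k=2):\n",
    "    # Collapsed-tree style search: rank nodes from every layer against the query\n",
    "    nodes = [node for layer in layers for node in layer]\n",
    "    query_vec = embed(query)\n",
    "    return sorted(nodes, key=lambda node: cosine(query_vec, embed(node)), reverse=True)[:k]\n",
    "\n",
    "\n",
    "chunks = [\n",
    "    \"Cinderella lived with her cruel stepmother and stepsisters.\",\n",
    "    \"Her stepsisters mocked Cinderella and made her do the chores.\",\n",
    "    \"A fairy godmother gave Cinderella a gown and a coach for the ball.\",\n",
    "    \"The prince searched the kingdom and found her with the glass slipper.\",\n",
    "]\n",
    "tree = build_tree(chunks)\n",
    "for depth, layer in enumerate(tree):\n",
    "    print(depth, layer)\n",
    "print(retrieve(tree, \"How did the prince find Cinderella?\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "raptor-toy-sketch-note",
   "metadata": {},
   "source": [
    "Because each new layer is strictly smaller than the one below it (neighbours are paired whenever the grouping step merges nothing), the loop always ends with a single root summary. In the real pipeline, `RA.add_documents` and `RA.answer_question` carry out the building and querying steps with the configured models."
   ]
  },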
  {
   "cell_type": "markdown",
   "id": "f4f58830-9004-48a4-b50e-61a855511d24",
   "metadata": {},
   "source": [
    "### Building the tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3753fcf9-0a8e-4ab3-bf3a-6be38ef6cd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from raptor import RetrievalAugmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e843edf",
   "metadata": {},
   "outputs": [],
   "source": [
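    "# With no config, RetrievalAugmentation uses its default models, which call the OpenAI API\n",
    "RA = RetrievalAugmentation()\n",
    "\n",
    "# Construct the tree: chunk, embed, cluster, and summarize the text\n",
    "RA.add_documents(text)"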
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f219d60a-1f0b-4cee-89eb-2ae026f13e63",
   "metadata": {},
   "source": [
    "### Querying from the tree\n",
    "\n",
    "```python\n",
    "question = \"your question here\"  # any question string\n",
    "RA.answer_question(question)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4037c5-ad5a-424b-80e4-a67b8e00773b",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer:\", answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5be7e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the tree by calling RA.save(\"path/to/save\")\n",
    "SAVE_PATH = \"demo/cinderella\"\n",
    "RA.save(SAVE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e845de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved tree by passing its path to RetrievalAugmentation\n",
    "RA = RetrievalAugmentation(tree=SAVE_PATH)\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "print(\"Answer:\", answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277ab6ea-1c79-4ed1-97de-1c2e39d6db2e",
   "metadata": {},
   "source": [
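    "## Using other open-source models for summarization, QA, and embeddings\n",
    "\n",
    "To use other models, such as Llama or Mistral, define your own classes by extending `BaseSummarizationModel`, `BaseQAModel`, and `BaseEmbeddingModel`, then pass instances of them to RAPTOR through `RetrievalAugmentationConfig`. The cells that follow use Gemma (`google/gemma-2b-it`) for summarization and question answering, and a Sentence-Transformers model for embeddings."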
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86cbe7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from raptor import BaseSummarizationModel, BaseQAModel, BaseEmbeddingModel, RetrievalAugmentationConfig\n",
    "from transformers import AutoTokenizer, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5cef43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gemma is a gated model: accept its license on the Hugging Face Hub, then authenticate here to download it.\n",
    "# Skip this step if you already have the model downloaded.\n",
    "from huggingface_hub import login\n",
    "login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245b91a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, pipeline\n",
    "import torch\n",
    "\n",
    "# Define your own summarization model by extending BaseSummarizationModel.\n",
    "class GEMMASummarizationModel(BaseSummarizationModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the text-generation pipeline for the Gemma model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.summarization_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            tokenizer=self.tokenizer,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            # Run on the GPU when CUDA is available, otherwise fall back to the CPU\n",
    "            device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
    "        )\n",
    "\n",
    "    def summarize(self, context, max_tokens=150):\n",
    "        # Wrap the summarization instruction in Gemma's chat template\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Write a summary of the following, including as many key details as possible: {context}:\"}\n",
    "        ]\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the summary; return_full_text=False drops the prompt from the output\n",
    "        outputs = self.summarization_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=max_tokens,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95,\n",
    "            return_full_text=False,\n",
    "        )\n",
    "\n",
    "        # Return only the generated summary text\n",
    "        return outputs[0][\"generated_text\"].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a171496d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define your own QA model by extending BaseQAModel.\n",
    "class GEMMAQAModel(BaseQAModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the text-generation pipeline for the Gemma model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.qa_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            tokenizer=self.tokenizer,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
    "        )\n",
    "\n",
    "    def answer_question(self, context, question):\n",
    "        # Combine the retrieved context and the question into a chat-formatted prompt\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Given Context: {context} Give the best full answer amongst the option to question {question}\"}\n",
    "        ]\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the answer; return_full_text=False drops the prompt from the output\n",
    "        outputs = self.qa_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=256,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95,\n",
    "            return_full_text=False,\n",
    "        )\n",
    "\n",
    "        # Return only the generated answer text\n",
    "        return outputs[0][\"generated_text\"].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878f7c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "\n",
    "# Define your own embedding model by extending BaseEmbeddingModel.\n",
    "class SBertEmbeddingModel(BaseEmbeddingModel):\n",
    "    def __init__(self, model_name=\"sentence-transformers/multi-qa-mpnet-base-cos-v1\"):\n",
    "        self.model = SentenceTransformer(model_name)\n",
    "\n",
    "    def create_embedding(self, text):\n",
    "        # Encode the text into a dense vector (a NumPy array)\n",
    "        return self.model.encode(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255791ce",
   "metadata": {},
   "outputs": [],
   "source": [
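    "# Bundle the custom summarization, QA, and embedding models into a single config\n",
    "RAC = RetrievalAugmentationConfig(\n",
    "    summarization_model=GEMMASummarizationModel(),\n",
    "    qa_model=GEMMAQAModel(),\n",
    "    embedding_model=SBertEmbeddingModel(),\n",
    ")"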
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee46f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a RetrievalAugmentation instance that uses the custom config\n",
    "RA = RetrievalAugmentation(config=RAC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe05daf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a new tree from the same text using the custom models\n",
    "with open('demo/sample.txt', 'r', encoding='utf-8') as file:\n",
    "    text = file.read()\n",
    "\n",
    "RA.add_documents(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eee5847",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer:\", answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RAPTOR_env",
   "language": "python",
   "name": "raptor_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
